{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transformers 설치 방법\n",
    "! pip install transformers datasets evaluate accelerate\n",
    "# 마지막 릴리스 대신 소스에서 설치하려면, 위 명령을 주석으로 바꾸고 아래 명령을 해제하세요.\n",
    "# ! pip install git+https://github.com/huggingface/transformers.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 이미지 분류[[image-classification]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "hide_input": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/tjAIM7BOYhw?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#@title\n",
    "from IPython.display import HTML\n",
    "\n",
    "HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/tjAIM7BOYhw?rel=0&amp;controls=0&amp;showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이미지 분류는 이미지에 레이블 또는 클래스를 할당합니다. 텍스트 또는 오디오 분류와 달리 입력은\n",
    "이미지를 구성하는 픽셀 값입니다. 이미지 분류에는 자연재해 후 피해 감지, 농작물 건강 모니터링, 의료 이미지에서 질병의 징후 검사 지원 등\n",
    "다양한 응용 사례가 있습니다.\n",
    "\n",
    "이 가이드에서는 다음을 설명합니다:\n",
    "\n",
    "1. [Food-101](https://huggingface.co/datasets/ethz/food101) 데이터 세트에서 [ViT](https://huggingface.co/docs/transformers/main/ko/tasks/model_doc/vit)를 미세 조정하여 이미지에서 식품 항목을 분류합니다.\n",
    "2. 추론을 위해 미세 조정 모델을 사용합니다.\n",
    "\n",
    "<Tip>\n",
    "\n",
    "이 작업과 호환되는 모든 아키텍처와 체크포인트를 보려면 [작업 페이지](https://huggingface.co/tasks/image-classification)를 확인하는 것이 좋습니다.\n",
    "\n",
    "</Tip>\n",
    "\n",
    "시작하기 전에, 필요한 모든 라이브러리가 설치되어 있는지 확인하세요:\n",
    "\n",
    "```bash\n",
    "pip install transformers datasets evaluate\n",
    "```\n",
    "\n",
    "Hugging Face 계정에 로그인하여 모델을 업로드하고 커뮤니티에 공유하는 것을 권장합니다. 메시지가 표시되면, 토큰을 입력하여 로그인하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Food-101 데이터 세트 가져오기[[load-food101-dataset]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "🤗 Datasets 라이브러리에서 Food-101 데이터 세트의 더 작은 부분 집합을 가져오는 것으로 시작합니다. 이렇게 하면 전체 데이터 세트에 대한\n",
    "훈련에 많은 시간을 할애하기 전에 실험을 통해 모든 것이 제대로 작동하는지 확인할 수 있습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "food = load_dataset(\"ethz/food101\", split=\"train[:5000]\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "데이터 세트의 `train`을 `train_test_split` 메소드를 사용하여 훈련 및 테스트 세트로 분할하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "food = food.train_test_split(test_size=0.2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "그리고 예시를 살펴보세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x7F52AFC8AC50>,\n",
       " 'label': 79}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "food[\"train\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "데이터 세트의 각 예제에는 두 개의 필드가 있습니다:\n",
    "\n",
    "- `image`: 식품 항목의 PIL 이미지\n",
    "- `label`: 식품 항목의 레이블 클래스\n",
    "\n",
    "모델이 레이블 ID에서 레이블 이름을 쉽게 가져올 수 있도록\n",
    "레이블 이름을 정수로 매핑하고, 정수를 레이블 이름으로 매핑하는 사전을 만드세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = food[\"train\"].features[\"label\"].names\n",
    "label2id, id2label = dict(), dict()\n",
    "for i, label in enumerate(labels):\n",
    "    label2id[label] = str(i)\n",
    "    id2label[str(i)] = label"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이제 레이블 ID를 레이블 이름으로 변환할 수 있습니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'prime_rib'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "id2label[str(79)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 전처리[[preprocess]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "다음 단계는 이미지를 텐서로 처리하기 위해 ViT 이미지 프로세서를 가져오는 것입니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoImageProcessor\n",
    "\n",
    "checkpoint = \"google/vit-base-patch16-224-in21k\"\n",
    "image_processor = AutoImageProcessor.from_pretrained(checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이미지에 몇 가지 이미지 변환을 적용하여 과적합에 대해 모델을 더 견고하게 만듭니다. 여기서 Torchvision의 [`transforms`](https://pytorch.org/vision/stable/transforms.html) 모듈을 사용하지만, 원하는 이미지 라이브러리를 사용할 수도 있습니다.\n",
    "\n",
    "이미지의 임의 부분을 크롭하고 크기를 조정한 다음, 이미지 평균과 표준 편차로 정규화하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor\n",
    "\n",
    "normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)\n",
    "size = (\n",
    "    image_processor.size[\"shortest_edge\"]\n",
    "    if \"shortest_edge\" in image_processor.size\n",
    "    else (image_processor.size[\"height\"], image_processor.size[\"width\"])\n",
    ")\n",
    "_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "그런 다음 전처리 함수를 만들어 변환을 적용하고 이미지의 `pixel_values`(모델에 대한 입력)를 반환하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transforms(examples):\n",
    "    examples[\"pixel_values\"] = [_transforms(img.convert(\"RGB\")) for img in examples[\"image\"]]\n",
    "    del examples[\"image\"]\n",
    "    return examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "전체 데이터 세트에 전처리 기능을 적용하려면 🤗 Datasets `with_transform`을 사용합니다. 데이터 세트의 요소를 가져올 때 변환이 즉시 적용됩니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "food = food.with_transform(transforms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이제 [DefaultDataCollator](https://huggingface.co/docs/transformers/main/ko/main_classes/data_collator#transformers.DefaultDataCollator)를 사용하여 예제 배치를 만듭니다. 🤗 Transformers의 다른 데이터 콜레이터와 달리, `DefaultDataCollator`는 패딩과 같은 추가적인 전처리를 적용하지 않습니다."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DefaultDataCollator\n",
    "\n",
    "data_collator = DefaultDataCollator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 평가[[evaluate]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "훈련 중에 평가 지표를 포함하면 모델의 성능을 평가하는 데 도움이 되는 경우가 많습니다.\n",
    "🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) 라이브러리로 평가 방법을 빠르게 가져올 수 있습니다. 이 작업에서는\n",
    "[accuracy](https://huggingface.co/spaces/evaluate-metric/accuracy) 평가 지표를 가져옵니다. (🤗 Evaluate [빠른 둘러보기](https://huggingface.co/docs/evaluate/a_quick_tour)를 참조하여 평가 지표를 가져오고 계산하는 방법에 대해 자세히 알아보세요):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "accuracy = evaluate.load(\"accuracy\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "그런 다음 예측과 레이블을 `compute`에 전달하여 정확도를 계산하는 함수를 만듭니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    predictions = np.argmax(predictions, axis=1)\n",
    "    return accuracy.compute(predictions=predictions, references=labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이제 `compute_metrics` 함수를 사용할 준비가 되었으며, 훈련을 설정하면 이 함수로 되돌아올 것입니다."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 훈련[[train]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<Tip>\n",
    "\n",
    "[Trainer](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.Trainer)를 사용하여 모델을 미세 조정하는 방법에 익숙하지 않은 경우, [여기](https://huggingface.co/docs/transformers/main/ko/tasks/../training#train-with-pytorch-trainer)에서 기본 튜토리얼을 확인하세요!\n",
    "\n",
    "</Tip>\n",
    "\n",
    "이제 모델을 훈련시킬 준비가 되었습니다! [AutoModelForImageClassification](https://huggingface.co/docs/transformers/main/ko/model_doc/auto#transformers.AutoModelForImageClassification)로 ViT를 가져옵니다. 예상되는 레이블 수, 레이블 매핑 및 레이블 수를 지정하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForImageClassification, TrainingArguments, Trainer\n",
    "\n",
    "model = AutoModelForImageClassification.from_pretrained(\n",
    "    checkpoint,\n",
    "    num_labels=len(labels),\n",
    "    id2label=id2label,\n",
    "    label2id=label2id,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "이제 세 단계만 거치면 끝입니다:\n",
    "\n",
    "1. [TrainingArguments](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.TrainingArguments)에서 훈련 하이퍼파라미터를 정의하세요. `image` 열이 삭제되기 때문에 미사용 열을 제거하지 않는 것이 중요합니다. `image` 열이 없으면 `pixel_values`을 생성할 수 없습니다. 이 동작을 방지하려면 `remove_unused_columns=False`로 설정하세요! 다른 유일한 필수 매개변수는 모델 저장 위치를 지정하는 `output_dir`입니다. `push_to_hub=True`로 설정하면 이 모델을 허브에 푸시합니다(모델을 업로드하려면 Hugging Face에 로그인해야 합니다). 각 에폭이 끝날 때마다, [Trainer](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.Trainer)가 정확도를 평가하고 훈련 체크포인트를 저장합니다.\n",
    "2. [Trainer](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.Trainer)에 모델, 데이터 세트, 토크나이저, 데이터 콜레이터 및 `compute_metrics` 함수와 함께 훈련 인수를 전달하세요.\n",
    "3. [train()](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.Trainer.train)을 호출하여 모델을 미세 조정하세요."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=\"my_awesome_food_model\",\n",
    "    remove_unused_columns=False,\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    learning_rate=5e-5,\n",
    "    per_device_train_batch_size=16,\n",
    "    gradient_accumulation_steps=4,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=3,\n",
    "    warmup_steps=0.1,\n",
    "    logging_steps=10,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"accuracy\",\n",
    "    push_to_hub=True,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=food[\"train\"],\n",
    "    eval_dataset=food[\"test\"],\n",
    "    processing_class=image_processor,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "훈련이 완료되면, 모든 사람이 모델을 사용할 수 있도록 [push_to_hub()](https://huggingface.co/docs/transformers/main/ko/main_classes/trainer#transformers.Trainer.push_to_hub) 메소드로 모델을 허브에 공유하세요:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<Tip>\n",
    "\n",
    "이미지 분류를 위한 모델을 미세 조정하는 자세한 예제는 다음 [PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb)을 참조하세요.\n",
    "\n",
    "</Tip>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 추론[[inference]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "좋아요, 이제 모델을 미세 조정했으니 추론에 사용할 수 있습니다!\n",
    "\n",
    "추론을 수행하고자 하는 이미지를 가져와봅시다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_dataset(\"ethz/food101\", split=\"validation[:10]\")\n",
    "image = ds[\"image\"][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png\" alt=\"image of beignets\"/>\n",
    "</div>\n",
    "\n",
    "미세 조정 모델로 추론을 시도하는 가장 간단한 방법은 `pipeline()`을 사용하는 것입니다. 모델로 이미지 분류를 위한 `pipeline`을 인스턴스화하고 이미지를 전달합니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.31856709718704224, 'label': 'beignets'},\n",
       " {'score': 0.015232225880026817, 'label': 'bruschetta'},\n",
       " {'score': 0.01519392803311348, 'label': 'chicken_wings'},\n",
       " {'score': 0.013022331520915031, 'label': 'pork_chop'},\n",
       " {'score': 0.012728818692266941, 'label': 'prime_rib'}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "classifier = pipeline(\"image-classification\", model=\"my_awesome_food_model\")\n",
    "classifier(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "원한다면, `pipeline`의 결과를 수동으로 복제할 수도 있습니다:\n",
    "\n",
    "이미지를 전처리하기 위해 이미지 프로세서를 가져오고 `input`을 PyTorch 텐서로 반환합니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoImageProcessor\n",
    "import torch\n",
    "\n",
    "image_processor = AutoImageProcessor.from_pretrained(\"my_awesome_food_model\")\n",
    "inputs = image_processor(image, return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "입력을 모델에 전달하고 logits을 반환합니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForImageClassification\n",
    "\n",
    "model = AutoModelForImageClassification.from_pretrained(\"my_awesome_food_model\")\n",
    "with torch.no_grad():\n",
    "    logits = model(**inputs).logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "확률이 가장 높은 예측 레이블을 가져오고, 모델의 `id2label` 매핑을 사용하여 레이블로 변환합니다:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'beignets'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predicted_label = logits.argmax(-1).item()\n",
    "model.config.id2label[predicted_label]"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
